{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Gated Recurrent Units (GRU) 门控循环单元",
   "id": "42389cce5a0c9cb1"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-30T04:25:20.804790Z",
     "start_time": "2024-08-30T04:25:15.698420Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l\n",
    "\n",
    "# Load the data\n",
    "batch_size = 32\n",
    "num_steps = 35\n",
    "train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)"
   ],
   "id": "20696bc1a45f8fda",
   "execution_count": 1,
   "outputs": []
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-30T04:28:22.159605Z",
     "start_time": "2024-08-30T04:28:22.137985Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Initialize the model parameters\n",
    "def get_params(vocab_size, num_hiddens, device):\n",
    "    \"\"\"Create the trainable parameters of the from-scratch GRU.\n",
    "\n",
    "    Args:\n",
    "        vocab_size: used as both the input and the output feature size.\n",
    "        num_hiddens: number of hidden units.\n",
    "        device: device on which every tensor is allocated.\n",
    "\n",
    "    Returns:\n",
    "        A list of 11 tensors with gradients enabled, ordered as\n",
    "        [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q].\n",
    "    \"\"\"\n",
    "    num_inputs = num_outputs = vocab_size\n",
    "\n",
    "    def small_normal(*shape):\n",
    "        # Standard-normal samples scaled by 0.01\n",
    "        return torch.randn(size=shape, device=device) * 0.01\n",
    "\n",
    "    def zero_bias(width):\n",
    "        return torch.zeros(width, device=device)\n",
    "\n",
    "    params = []\n",
    "    # One (input weight, hidden weight, bias) triple each for the update gate,\n",
    "    # the reset gate and the candidate hidden state, in that order.\n",
    "    for _ in range(3):\n",
    "        params += [small_normal(num_inputs, num_hiddens),\n",
    "                   small_normal(num_hiddens, num_hiddens),\n",
    "                   zero_bias(num_hiddens)]\n",
    "\n",
    "    # Output-layer weight and bias\n",
    "    params += [small_normal(num_hiddens, num_outputs), zero_bias(num_outputs)]\n",
    "\n",
    "    # Enable gradient tracking on every parameter\n",
    "    for tensor in params:\n",
    "        tensor.requires_grad_(True)\n",
    "\n",
    "    return params"
   ],
   "id": "870e604bc38d530e",
   "execution_count": 2,
   "outputs": []
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-30T04:29:31.416896Z",
     "start_time": "2024-08-30T04:29:31.406702Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Initialization function for the hidden state\n",
    "def init_gru_state(batch_size, num_hiddens, device):\n",
    "    \"\"\"Return the initial hidden state as a 1-tuple holding an all-zero\n",
    "    tensor of shape (batch_size, num_hiddens) allocated on `device`.\"\"\"\n",
    "    initial_hidden = torch.zeros(batch_size, num_hiddens, device=device)\n",
    "    return (initial_hidden,)"
   ],
   "id": "d4804242cd991796",
   "execution_count": 3,
   "outputs": []
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-30T04:34:41.334342Z",
     "start_time": "2024-08-30T04:34:41.328103Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Define the gated recurrent unit model\n",
    "def gru(inputs, state, params):\n",
    "    \"\"\"Run the from-scratch GRU over a sequence of inputs.\n",
    "\n",
    "    Args:\n",
    "        inputs: iterable yielding one input X per step; each X is\n",
    "            matrix-multiplied with the input-to-hidden weights.\n",
    "        state: 1-tuple holding the initial hidden state.\n",
    "        params: the 11 tensors\n",
    "            [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q].\n",
    "\n",
    "    Returns:\n",
    "        A pair (outputs, (H,)): the per-step outputs concatenated along\n",
    "        dim 0, and the final hidden state wrapped in a 1-tuple.\n",
    "    \"\"\"\n",
    "    (W_xz, W_hz, b_z,\n",
    "     W_xr, W_hr, b_r,\n",
    "     W_xh, W_hh, b_h,\n",
    "     W_hq, b_q) = params\n",
    "    (hidden,) = state  # initial hidden state\n",
    "    step_outputs = []\n",
    "\n",
    "    # `@` is matrix multiplication, `*` is element-wise multiplication\n",
    "    for X in inputs:\n",
    "        update_gate = torch.sigmoid((X @ W_xz) + (hidden @ W_hz) + b_z)\n",
    "        reset_gate = torch.sigmoid((X @ W_xr) + (hidden @ W_hr) + b_r)\n",
    "        # Candidate hidden state: the reset gate scales the previous state\n",
    "        candidate = torch.tanh((X @ W_xh) + ((reset_gate * hidden) @ W_hh) + b_h)\n",
    "        # Next state: blend of the previous state and the candidate\n",
    "        hidden = update_gate * hidden + (1 - update_gate) * candidate\n",
    "        step_outputs.append(hidden @ W_hq + b_q)\n",
    "\n",
    "    return torch.cat(step_outputs, dim=0), (hidden,)"
   ],
   "id": "120b40bd4a8922f7",
   "execution_count": 4,
   "outputs": []
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Training and prediction\n",
    "vocab_size = len(vocab)\n",
    "num_hiddens = 256\n",
    "device = d2l.try_gpu()\n",
    "num_epochs = 500\n",
    "lr = 1\n",
    "\n",
    "# Wire the from-scratch pieces (parameter factory, state initializer,\n",
    "# forward function) into the d2l scratch-model wrapper\n",
    "model = d2l.RNNModelScratch(vocab_size, num_hiddens, device,\n",
    "                            get_params, init_gru_state, gru)\n",
    "\n",
    "d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)"
   ],
   "id": "17643bae3af035f8",
   "execution_count": 6,
   "outputs": []
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## 简洁实现",
   "id": "b85350c71dd20392"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Same training call as above, but with PyTorch's built-in GRU layer\n",
    "num_inputs = vocab_size\n",
    "gru_layer = nn.GRU(input_size=num_inputs, hidden_size=num_hiddens)\n",
    "model = d2l.RNNModel(gru_layer, len(vocab)).to(device)\n",
    "d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)"
   ],
   "id": "deb128431c7c800a",
   "execution_count": 9,
   "outputs": []
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-30T04:42:04.770663Z",
     "start_time": "2024-08-30T04:42:04.768872Z"
    }
   },
   "cell_type": "code",
   "source": "",
   "id": "f1974c0ac91c35dc",
   "execution_count": 8,
   "outputs": []
  },
  {
   "metadata": {},
   "cell_type": "code",
   "execution_count": null,
   "source": "",
   "id": "f08daf08b3a14bcb",
   "outputs": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
